function [u,v,energies] = regularized_NLTV(in,nbIter,radius,...
    mu,lambda,gama, hsigma, init, application)
%% Jointly estimate an image u and its non-local weights v by alternating
%% PALM steps on u and on v.
%%
%% in          observed image (for 'zooming', the downsampled image)
%% nbIter      number of PALM iterations
%% radius      neighbourhood radius; the window is sizeB x sizeB with
%%             sizeB = 2*radius+1
%% mu          smoothing parameter of the Huber-smoothed NLTV
%% lambda      weight of the data-fidelity term
%% gama        weight of the regularity term R(v)
%% hsigma      bandwidth passed to the initial-weight routine
%% init        initial weights: 3 -> weight_v_square6, 9 -> weight_v_square9,
%%             33 -> weight_v_square610, 16 -> weight_v_square16
%%             (9 and 33 read application.miss_data)
%% application struct with field name ('denoising', 'inpainting' or
%%             'zooming') plus the fields that application needs; the
%%             optional field init_u is the starting image (default: in)
%%
%% Returns the final image u, the weights v, and energies: one row of
%% energy terms (see getEnergy) before the first iteration and one more
%% after every iteration.

sizeB = 2*radius+1;
size_in = size(in);

switch application.name
    case 'zooming'
        % "in" is the downsampled result, so the working image is larger
        size_in = size_in*application.scale;
end

if isfield(application,'init_u')
    init_u = application.init_u;
else
    init_u = in;
end

%% initialization of u and v
switch init
    case 3
        v = weight_v_square6(init_u,sizeB,hsigma);
    case 9
        v = weight_v_square9(init_u,sizeB,hsigma,application.miss_data);
    case 33
        v = weight_v_square610(init_u,sizeB,hsigma,application.miss_data);
    case 16
        v = weight_v_square16(init_u,sizeB,hsigma);
    otherwise
        % Fail here: otherwise v stays undefined and the error would surface
        % later with an unrelated message.
        error('regularized_NLTV:badInit', ...
            'Unsupported init value %g (expected 3, 9, 16 or 33).', init);
end

u = init_u;

energies = getEnergy(u,v,in,mu,lambda,gama,application);

disp(['gama', num2str(gama),'   mu',num2str(mu)]);

% Step size of the u-update; it depends only on sizeB and mu.
Lv = ( sqrt(2) * sizeB /mu) * 1.001;

for it=1:nbIter
    %% PALM iteration in u: gradient step on NLTV, then prox of the data term
    u = u - (1/Lv) * grad_NLTV_u(u,v,mu);
    u = prox_data_term(u,in,Lv,lambda,application);

    %% PALM iteration in v: gradient step on NLTV + gama*R, then projection
    grad_v = grad_NLTV_v(u,v,mu) + grad_R(v,gama,size_in);
    Lu = compute_Lu(u,v,mu,gama) * 1.001;
    v = v - grad_v/Lu;
    v = prox_v_term(v);

    energies = [energies; getEnergy(u,v,in,mu,lambda,gama,application)];
end
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = R(v,size_in)
%% Regularity energy of the weights v: the squared norm of B*v, i.e. the sum
%% of squared differences between every weight map (a column of v seen as
%% an image of size size_in) and its circular right and lower neighbours.
%% Prints a diagnostic if B*v has a non-zero imaginary part.
sizeQ = size(v,2);
w = B(v,sizeQ,size_in);

% any() is enough to detect a non-zero imaginary part; norm() of a matrix
% would compute its largest singular value (an SVD) for nothing.
if any(imag(w(:)))
    disp('ERREUR DANS R (i.e. B)')
end

out = sum(sum( conj(w) .* w ));
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = prox_data_term(u,in,Lv,lambda,application)
%% Proximal step (step size 1/Lv) for the data-fidelity term of the
%% selected application.
%%
%% 'denoising', data_term == 1 : minimiser of lambda*||x-in||^2 + Lv/2*||x-u||^2,
%%                               i.e. (2*lambda*in + Lv*u)/(2*lambda + Lv)
%% 'denoising', data_term == 2 : projection of u onto the ball
%%                               ||x - in||_2 <= lambda*sqrt(numel) (RMS <= lambda)
%% 'inpainting'                : the data_term == 1 formula on pixels with
%%                               miss_data == 1, identity on missing pixels
%% 'zooming'                   : solves (lam2HH + Lv*I) x = lam2HT*in(:) + Lv*u(:)
%%                               with pcg (default tolerance / iterations).
%%                               NOTE(review): presumably lam2HH = 2*lambda*H'*H
%%                               and lam2HT = 2*lambda*H' -- confirm.
%% Any other application (or data_term) raises an error.

switch application.name
    case 'denoising'
        if application.data_term == 1
            out = (2*lambda*in + Lv*u)/(2*lambda+Lv);

        elseif application.data_term == 2
            % projection onto the ball || u-in ||_2 <= lambda*sqrt(N),
            % compared in squared form to avoid a square root
            size_u = size(u);
            tau = lambda * lambda *size_u(1)* size_u(2);
            residual = u-in;
            norm_r = sum(sum(residual.*residual));
            if norm_r>tau
                step = sqrt(tau/norm_r);
                out = in + step*residual;
            else
                out = u;
            end
        else
            error('regularized_NLTV:badDataTerm', ...
                'Unsupported data_term %g for denoising.', application.data_term);
        end
    case 'inpainting'
        idx1 = (application.miss_data == 1);
        out = (2*lambda*in + Lv*u)/(2*lambda+Lv).*idx1 ...
            + u.*(~idx1);     % missing part (inpainting domain)
    case 'zooming'
        lam2HH = application.lam2HH_matrix;
        lam2HT = application.lam2HT_matrix;
        N_zooming = size(lam2HH,1);
        A_cg = sparse(lam2HH + Lv*speye(N_zooming));
        b_cg = lam2HT*in(:) + Lv*u(:);
        % Requesting the flag output keeps pcg from printing its
        % convergence message, as before.
        [out, ~] = pcg(A_cg,b_cg);
        out = reshape(out,size(u));
    otherwise
        % Previously only printed, leaving "out" unset.
        error('regularized_NLTV:unknownApplication', ...
            'Error, no such application: %s', application.name);
end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = NLTV(u,v,mu)
%% Huber-smoothed non-local total variation of u with weights v.
%% With s_p the Euclidean norm of row p of D(u) (see D), returns
%%     sum_p   s_p - mu/2        if s_p^2 > mu^2
%%             s_p^2/(2*mu)      otherwise
%% Prints a diagnostic if D(u) or the result has a non-zero imaginary part.

mu2 = mu*mu;
Du = D(u, v, size(u));

% any() detects a non-zero imaginary part; norm() of a matrix would run
% an SVD just for this check.
if any(imag(Du(:)))
    disp('ERREUR DANS Du')
end

sq_norm_Du = sum( Du.* conj(Du) , 2 );   % squared norm per pixel
is_large = ( real(sq_norm_Du) > mu2 );
out = sum( is_large .* (real(sqrt(sq_norm_Du)) - mu/2) ...
    + (1-is_large) .* sq_norm_Du/(2*mu) );

if any(imag(out(:)))
    disp('ERREUR DANS NLTV')
end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = grad_NLTV_u(u,v,mu)
%% Gradient of the Huber-smoothed NLTV with respect to u:
%%     D'( Du ./ max(|Du|, mu) )
%% where |Du| is the Euclidean norm of each row of Du = D(u) (see D).

img_size = size(u);
Du = D(u, v, img_size);

row_norm = sqrt(sum(Du.*Du, 2));
scale = repmat(max(row_norm, mu), 1, size(v,2));

out = D(Du ./ scale, v, img_size);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function out = D(in,v,size_in)
%% Weighted non-local difference operator D, or its adjoint; the direction
%% is chosen from the size of "in". D is used by NLTV and grad_NLTV_u.
%%
%% Let S_k be the circular translation of the k-th non-central offset of the
%% sizeB x sizeB window (translate_u with direction 1), where
%% sizeB^2 = size(v,2)+1, and w_k = sqrt(max(v(:,k),0)).
%%   Forward (in is an image of size size_in):
%%       out(:,k) = w_k .* ( in(:) - S_k in(:) )        -> same size as v
%%   Adjoint (in has the same size as v):
%%       out = reshape( sum_k ( y_k - S_k' y_k ), size_in ),  y_k = w_k .* in(:,k)
%% Negative weights are clamped to 0 in BOTH directions so that the adjoint
%% really is the adjoint of the forward operator.

nb_offsets = size(v,2);            % sizeB^2 - 1: the central offset is excluded
sizeB = sqrt( nb_offsets+1 );
sqrt_v = sqrt( max(v,0) );

if isequal(size(in), size(v))
    %% apply adjoint of D to 'in'
    weighted = sqrt_v .* in;
    back_sum = zeros(size(weighted,1), 1);
    for q = 1:nb_offsets+1
        % v has no column for the central offset, so offsets after it are
        % stored one column to the left.
        if 2*q <= nb_offsets
            col = q;
        else
            col = q-1;
        end

        shifted = translate_u(reshape(weighted(:,col),size_in), q, sizeB, 0);

        if ~isempty(shifted)       % empty for the central offset
            back_sum = back_sum + shifted(:);
        end
    end
    out = reshape( sum(weighted,2) - back_sum , size_in );

else
    %% apply D to 'in'
    in_col = in(:);
    out = zeros(numel(in_col), nb_offsets);
    col = 0;
    for q = 1:nb_offsets+1
        shifted = translate_u(in,q,sizeB,1);

        if ~isempty(shifted)       % empty for the central offset
            col = col + 1;
            out(:,col) = in_col - shifted(:);
        end
    end

    out = sqrt_v .* out;
end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function out = translate_u(in,index,sizeB,direction)
%% Circularly translate the 2-D image "in" by the offset that "index" selects
%% in a sizeB x sizeB window read row by row, the centre of the window being
%% the zero offset. For sizeB = 3:
%%
%%      1  | 2  |  3
%%      4  | 5  |  6
%%      7  | 8  |  9
%%
%% With r = floor(sizeB/2), index k gives the offset
%%     tx = floor((k-1)/sizeB) - r     (along rows)
%%     ty = mod(k-1, sizeB) - r        (along columns)
%% direction == 0 : out(i,j) = in(i-tx, j-ty)   (the adjoint translation)
%% otherwise      : out(i,j) = in(i+tx, j+ty)
%% with indices wrapped around the image borders. The zero offset returns
%% [] so that callers can skip it.

radius = floor(sizeB/2);
row_block = floor((index-1)/sizeB);
tx = row_block - radius;
ty = (index-1) - row_block*sizeB - radius;

if direction == 0
    tx = -tx;
    ty = -ty;
end

if tx == 0 && ty == 0
    out = [];
else
    % circshift moves content forward by k, so shifting by -t reads in(i+t,...)
    out = circshift(in, [-tx, -ty]);
end
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = prox_v_term(v)
%% Proximal step for the weights v: the projection of v onto the
%% constraint set for v, computed by project_v.
projected = project_v(v);
out = projected;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = grad_R(v,gama,size_in)
%% Gradient of gama*R(v) = gama*||B v||^2 with respect to v, i.e.
%% 2*gama*B'(B v): B is applied forward, then its adjoint is applied to the
%% result (B picks the direction from the column count).
nb_cols = size(v,2);
Bv = B(v, nb_cols, size_in);             % forward: 2*nb_cols columns
out = 2*gama*B(Bv, nb_cols, size_in);    % adjoint: back to nb_cols columns
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = B(v,sizeQ,size_in)
%% Finite-difference operator of the weight regularity (used by R and
%% grad_R), or its adjoint; the direction is chosen from the column count.
%%
%% Forward (size(v,2) == sizeQ): each column of v, seen as an image of size
%% size_in, is differenced with its circular right and lower neighbours:
%%     out = [ v - S_right(v) , v - S_down(v) ]          (2*sizeQ columns)
%% Adjoint (otherwise; v holds 2*sizeQ columns [a, b]):
%%     out(:,q) = ( a_q - S_right'(a_q) ) + ( b_q - S_down'(b_q) )

if size(v,2) ~= sizeQ
    %% compute the adjoint of the operator B
    horiz = v(:,1:sizeQ);            % corresponds to q'=(0,1)
    vert = v(:,sizeQ+1:2*sizeQ);     % corresponds to q'=(1,0)
    cols = cell(1,sizeQ);
    for k = 1:sizeQ
        back_h = translate_u(reshape(horiz(:,k),size_in),6,3,0);  % one unit horizontally, reversed
        back_v = translate_u(reshape(vert(:,k),size_in),8,3,0);   % one unit vertically, reversed
        cols{k} = ( horiz(:,k) - back_h(:) ) + ( vert(:,k) - back_v(:) );
    end
    out = [cols{:}];
    return
end

%% compute the operator B
right = cell(1,sizeQ);
down = cell(1,sizeQ);
for k = 1:sizeQ
    img = reshape( v(:,k) , size_in );
    shifted_r = translate_u(img,6,3,1);   % one unit horizontally
    shifted_d = translate_u(img,8,3,1);   % one unit vertically
    right{k} = shifted_r(:);
    down{k} = shifted_d(:);
end
out = [v - [right{:}], v - [down{:}]];
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = compute_Lu(u,v,mu,gama)
%% Step-size constant for the PALM update of v:
%%     sqrt(2) * (card_V * 6 * gama) + 0.001,   card_V = 2
%% where card_V is the number of neighbours used by the regularity R (see B).
%% Only gama is used; u, v and mu are accepted for interface compatibility.
nb_neighbours = 2;   % card_V
out = sqrt(2) * (nb_neighbours*6*gama) + 0.001;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = grad_NLTV_v(u,v,mu)
%% Gradient of the Huber-smoothed NLTV with respect to the weights v.
%% With d_k = (u - S_k u).^2 (see A) and s = sqrt(sum_k v(:,k).*d_k) the
%% per-pixel magnitude, column k of the result is d_k .* w/2 where
%% w = 1/s when s >= mu and 1/mu otherwise.

img_size = size(u);
win = sqrt( size(v,2)+1 );

mag = real( sqrt( A(u, v, img_size, win) ));   % per-pixel magnitude s
is_big = (mag >= mu);
w_star = is_big ./ max(mag, mu) + (1 - is_big) / mu;

out = A(u, w_star/2, img_size, win);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = A(u,in,size_in,sizeB)
%% Squared-difference operator used by grad_NLTV_v, or its adjoint; the
%% direction is chosen from the size of "in".
%% Let d_k = (u(:) - S_k u(:)).^2 for every non-central offset k of the
%% sizeB x sizeB window (S_k: translate_u with direction 1).
%%   "in" the same size as u (adjoint) : out(:,k) = in(:) .* d_k
%%   otherwise (in has one column per offset) :
%%                      out = reshape( sum_k in(:,k) .* d_k , size_in )

u_col = u(:);
nb_cells = sizeB*sizeB;
sq_diff = cell(1, nb_cells);
for q = 1:nb_cells
    shifted = translate_u(u, q, sizeB, 1);

    if ~isempty(shifted)           % empty for the central offset
        d = u_col - shifted(:);
        sq_diff{q} = d .* d;
    end
end
sq_diff = [sq_diff{:}];            % empty cells are dropped

if all(size(u) == size(in))
    %% apply adjoint of A to 'in'
    out = repmat(in(:), 1, size(sq_diff,2)) .* sq_diff;
else
    %% apply A to 'in'
    out = reshape( sum(in.*sq_diff, 2) , size_in );
end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = project_v(in)
%% Row-wise projection onto the constraint set for v: each row of "out" is
%% nonnegative and sums to 1. For every row the entries are sorted, the
%% threshold mu with sum(max(row - mu, 0)) == 1 is located between two
%% consecutive sorted values by linear interpolation, and the row is
%% soft-thresholded by mu (a direct, sort-based computation).
%%
%% in  : sizeU x sizeQ matrix, one row per pixel
%% out : matrix of the same size
%%
%% NOTE(review): a 1e-6*randn perturbation is added to "in" before the
%% projection, so the output is random. Presumably this breaks ties between
%% equal entries, which would make the denominator of alpha below zero --
%% confirm before removing it.

[sizeU, sizeQ] = size(in);
out = in+0.000001 * randn(size(in));
% each row sorted in decreasing order: s_1 >= s_2 >= ... >= s_sizeQ
sorted_out = sort(out,2, 'descend');

% S(:,k) = sum_{j<k} (s_j - s_k): the row sum left after thresholding at
% s_k. It starts at 0 and is nondecreasing in k. The extra last column
% produced by the concatenation is dropped on the next line.
sum_of_thresholded = [zeros([sizeU,1]), cumsum(sorted_out,2)] - (ones([sizeU,1])* (0:sizeQ)).*[sorted_out,sorted_out(:,end)];
sum_of_thresholded = sum_of_thresholded(:,1:sizeQ);

% NB: the first output "i" is the maximum VALUE, not its position ("tmp",
% unused): the largest k with S(:,k) < 1, i.e. how many entries of the row
% stay positive after thresholding.
[i,tmp] = max( (sum_of_thresholded<1) .*  (ones([sizeU,1])* (1:sizeQ)),[] ,2 );

% threshold when all sizeQ entries stay positive: sum(row - mu) == 1
mu_small = (sum(out,2) - 1)/sizeQ;

% linear indexes of the entries (row, i)
indexes = (i-1)*sizeU + (1:sizeU)';
sum_at_indexes = sum_of_thresholded(indexes);

% linear indexes of (row, i+1); a dummy index 1 when i == sizeQ
next_indexes = (i<sizeQ) .* (indexes+sizeU-1)+1;
% for i == sizeQ the fake value S_i + 1 only keeps alpha finite; mu_large
% is discarded for those rows when mu is formed below
sum_at_next_indexes = (i<sizeQ) .* sum_of_thresholded(next_indexes) + (i==sizeQ) .* (sum_at_indexes+1);
% between s_{i+1} and s_i the thresholded row sum is linear in mu; alpha
% places mu where that sum equals 1
alpha = (1-sum_at_next_indexes) ./ (sum_at_indexes-sum_at_next_indexes);
mu_large = (alpha.*sorted_out(indexes) + (1-alpha).*sorted_out(next_indexes));

% per row: all entries active -> mu_small, otherwise the interpolated value
mu = (i==sizeQ) .* mu_small + (i<sizeQ) .* mu_large;

% soft threshold
out = max(0,out-mu*ones([1 sizeQ]));
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function tmp1 = getEnergy(u,v,in,mu,lambda,gama,application)
%% Value of the three energy terms at (u, v), as the 1x3 row
%%     [ gama*R(v), NLTV(u,v,mu), lambda*||r||^2 ]
%% where the residual r is
%%     'zooming'    : H*u(:) - in(:), with H = application.H_matrix
%%     'inpainting' : u - in on the pixels with miss_data == 1
%%     'denoising'  : u - in
%% Any other application.name raises an error (previously the residual was
%% simply left undefined).
switch application.name
    case 'zooming'
        H = application.H_matrix;
        residual = H*u(:) - in(:);
    case 'inpainting'
        residual = u-in;
        residual = residual(application.miss_data == 1);
    case 'denoising'
        residual = u-in;
    otherwise
        error('regularized_NLTV:unknownApplication', ...
            'Error, no such application: %s', application.name);
end
tmp1 = [gama*R(v,size(u)), NLTV(u,v,mu), lambda * sum(sum(residual.*residual))];
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = weight_v_square610(in,sizeB,hsigma,miss_data)
%% Initial Gaussian weights for inpainting that ignore missing pixels.
%% "in" is first scaled by its maximum (when that maximum is non-zero). For
%% every pixel p and every non-central offset q of the sizeB x sizeB window
%% the weight is exp(-(in(p) - in(p+q))^2 / hsigma^2). It is set to 0 when
%% that exponent reaches 29, or when the neighbour p+q is missing
%% (miss_data == 0), and multiplied by miss_data(p).
%% Rows of missing pixels, and rows where every weight is 0 (which would
%% otherwise give 0/0), become uniform. Each row is then normalised to sum
%% to 1. Returns a numel(in) x (sizeB^2-1) matrix.

v_col = sizeB*sizeB-1;

in_max = max(in(:));
if in_max ~= 0
    in = in./in_max;
end
center_vals = in(:);

is_missing = (miss_data==0);
% Finite placeholder at missing pixels; those differences are masked below.
in_filled = in;
in_filled(is_missing) = 0;

out = zeros(numel(in), v_col);
col = 0;
for q = 1:v_col+1
    shifted = translate_u(in_filled,q,sizeB,1);
    if isempty(shifted)
        continue;   % central offset
    end

    % zero where the neighbour is missing; otherwise miss_data of the centre
    neighbour_missing = translate_u(is_missing,q,sizeB,1);
    keep = (~neighbour_missing).*miss_data;
    keep = keep(:);

    sq_ratio = ((center_vals - shifted(:)).*keep).^2/(hsigma*hsigma);

    % avoid tiny values: exp(-29) = 2.5437e-13
    below = sq_ratio<29;
    col = col + 1;
    out(:,col) = exp(-sq_ratio.*below).*below.*keep;
end

out(is_missing(:),:) = 1;

% Known pixels whose weights all vanished would be divided by zero.
out(~any(out,2),:) = 1;

out = out./repmat(sum(out,2),1,v_col);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = weight_v_square9(in,sizeB,hsigma,miss_data)
%% Binary initial weights: for every pixel p and every non-central offset q
%% of the sizeB x sizeB window, the weight is 1 when |in(p) - in(p+q)| < eps
%% and 0 otherwise. Rows of missing pixels (miss_data == 0), and rows with
%% no equal neighbour (which would otherwise give 0/0), become uniform.
%% Each row is then normalised to sum to 1. hsigma is not used.
%% Returns a numel(in) x (sizeB^2-1) matrix.

v_col = sizeB*sizeB-1;
center_vals = in(:);

out = zeros(numel(in), v_col);
col = 0;
for q = 1:v_col+1
    shifted = translate_u(in,q,sizeB,1);

    if ~isempty(shifted)       % empty for the central offset
        col = col + 1;
        out(:,col) = abs(center_vals - shifted(:)) < eps;
    end
end

out(miss_data(:)==0,:) = 1;

% Without an equal neighbour the row sum would be 0.
out(~any(out,2),:) = 1;

out = out./repmat(sum(out,2),1,v_col);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = weight_v_square6(in,sizeB,hsigma)
%% Initial Gaussian weights: "in" is scaled by its maximum (skipped when the
%% maximum is 0, which would otherwise produce NaN weights). For every pixel
%% p and every non-central offset q of the sizeB x sizeB window the weight
%% is exp(-(in(p) - in(p+q))^2 / hsigma^2), or 0 when that exponent reaches
%% 29. Rows whose weights are all below eps become uniform, and each row is
%% normalised to sum to 1. Returns a numel(in) x (sizeB^2-1) matrix.

v_col = sizeB*sizeB-1;

in_max = max(in(:));
if in_max ~= 0
    in = in./in_max;
end
center_vals = in(:);

out = zeros(numel(in), v_col);
col = 0;
for q = 1:v_col+1
    shifted = translate_u(in,q,sizeB,1);

    if ~isempty(shifted)       % empty for the central offset
        col = col + 1;
        out(:,col) = center_vals - shifted(:);
    end
end

vv = out.^2/(hsigma*hsigma);

% avoid small values   exp(-29) = 2.5437e-13
ind_vv = vv<29;
out = exp(-vv.*ind_vv).*ind_vv;

% rows containing only (near) zeros get uniform weights
out(all(abs(out)<eps,2),:) = 1;

row_sum = sum(out,2);
if any(abs(row_sum)<eps)
    disp('divide by zero in intial v')
end
out = out./repmat(row_sum,1,v_col);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function out = weight_v_square16(in,sizeB,hsigma)
%% Initial Gaussian weights: "in" is scaled by its maximum (when non-zero).
%% For every pixel p and every non-central offset q of the sizeB x sizeB
%% window the weight is exp(-(in(p) - in(p+q))^2 / hsigma^2), or 0 when that
%% exponent reaches 29.
%% Pixels whose neighbourhood is constant get a difference of 0.2 for every
%% offset, which yields uniform weights after normalisation. Rows whose
%% weights all vanished (which would otherwise give 0/0) are made uniform
%% too. Each row is then normalised to sum to 1.
%% Returns a numel(in) x (sizeB^2-1) matrix.

v_col = sizeB*sizeB-1;
if (max(in(:))~=0)
    in = in./max(in(:));
end
center_vals = in(:);

out = zeros(numel(in), v_col);
col = 0;
for q = 1:v_col+1
    shifted = translate_u(in,q,sizeB,1);

    if ~isempty(shifted)       % empty for the central offset
        col = col + 1;
        out(:,col) = center_vals - shifted(:);
    end
end

% constant neighbourhoods: all differences (near) zero
out(all(abs(out)<eps,2),:) = 0.2;

vv = out.^2/(hsigma*hsigma);

% avoid small values   exp(-29) = 2.5437e-13
ind_vv = vv<29;
out = exp(-vv.*ind_vv).*ind_vv;

% Every weight of a row may fall under the cutoff (e.g. small hsigma);
% use uniform weights instead of dividing by zero.
out(~any(out,2),:) = 1;

row_sum = sum(out,2);
if any(abs(row_sum)<eps)
    disp('divide by zero in intial v')
end
out = out./repmat(row_sum,1,v_col);
end